Conversation

@NicoGrande commented Dec 2, 2025

Description

This PR finishes the work started by @gagika in #2767. Credits to @gagika for helping with this feature!

This PR updates train_rl.py and the other modules involved in the Tunix integration so that the additional configuration needed for the MaxText-on-vLLM flow can be passed through to Tunix.

Specifically, this PR adds vllm_additional_config and vllm_hf_config_path as new arguments whose values are piped through to Tunix for RL.
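
For illustration, a minimal sketch of how these two flags could be gathered and handed to the rollout engine. The helper name `build_vllm_rollout_config` and the attribute-access pattern are hypothetical, not the actual MaxText or Tunix API:

```python
import json
from typing import Any, Dict


def build_vllm_rollout_config(config: Any) -> Dict[str, Any]:
  """Hypothetical helper: collect the vLLM-related options added by this PR.

  `vllm_hf_config_path` points at the HF-style adapter config shipped with
  MaxText, and `vllm_additional_config` arrives on the command line as a JSON
  string (see the test commands below for concrete values).
  """
  additional = {}
  if getattr(config, "vllm_additional_config", ""):
    # Parse the JSON string into a nested dict of engine options.
    additional = json.loads(config.vllm_additional_config)
  return {
      "hf_config_path": config.vllm_hf_config_path,
      "additional_config": additional,
  }
```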

Additionally, this PR makes small modifications to tunix_adapter.py so that no-op mappings can be used when running RL with MaxText on vLLM.
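
As a rough sketch of what a no-op mapping can look like (the function name is illustrative and assumes nothing about the adapter's real hook names), an identity function that leaves the MaxText parameter tree untouched where the adapter would otherwise rename or reshard weights:

```python
def identity_mapping(params):
  """No-op mapping: return the MaxText parameter tree unchanged.

  When MaxText itself backs the vLLM rollout, the weights are already in the
  layout the engine expects, so the usual renaming/resharding step can be an
  identity (or skipped entirely).
  """
  return params
```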

Tests

Gemma3-4B:

Local (v6e-4 VM):

NEW_MODEL_DESIGN=True  HF_TOKEN=$HF_TOKEN TPU_BACKEND_TYPE=jax python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
  model_name=gemma3-4b \
  tokenizer_path=google/gemma-3-4b-it \
  run_name=$WORKLOAD \
  base_output_directory=$OUTPUT_PATH \
  hf_access_token=$HF_TOKEN \
  scan_layers=False \
  load_parameters_path="gs://maxtext-gemma/unified/gemma3/4b/unscanned/2025-08-09-01-17/0/items" \
  vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
  vllm_additional_config='{"maxtext_config": {"model_name": "gemma3-4b", "max_prefill_predict_length": 28, "max_target_length": 32, "ici_tensor_parallelism": 4}}'

Output: logs
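
For reference, the vllm_additional_config value in the command above is plain JSON; a quick standalone check of how it parses (this makes no assumption about how Tunix consumes the resulting dict):

```python
import json

raw = ('{"maxtext_config": {"model_name": "gemma3-4b", '
       '"max_prefill_predict_length": 28, "max_target_length": 32, '
       '"ici_tensor_parallelism": 4}}')
parsed = json.loads(raw)
assert parsed["maxtext_config"]["ici_tensor_parallelism"] == 4
```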

Qwen3-30B-A3B:

v5p-64 Cluster:

xpk workload create-pathways --workload $WORKLOAD_NAME --num-slices=1 --priority very-high --docker-image gcr.io/tpu-prod-env-multipod/maxtext-vllm-tpu-test:stable --cluster $CLUSTER_NAME --tpu-type=$TPU_TYPE --zone=$ZONE --project=$PROJECT_ID --command "cd /app && pip install --no-deps -e . && VLLM_TORCH_PROFILER_DIR=$OUTPUT_PATH MAXTEXT_PKG_DIR=/app/src/MaxText HF_TOKEN=$HF_TOKEN JAX_RANDOM_WEIGHTS=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 JAX_PLATFORMS=proxy,cpu JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
  model_name=qwen3-30b-a3b \
  tokenizer_path=Qwen/Qwen3-30B-A3B \
  run_name=$WORKLOAD \
  base_output_directory=$OUTPUT_PATH \
  hf_access_token=$HF_TOKEN \
  batch_size=16 \
  rollout_data_parallelism=4 \
  rollout_tensor_parallelism=4 \
  hbm_utilization_vllm=0.60 \
  scan_layers=False \
  load_parameters_path=$CHECKPOINT_PATH \
  allow_split_physical_axes=True \
  vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
  vllm_additional_config='{maxtext_config: {model_name: qwen3-30b-a3b, allow_split_physical_axes: true, log_config: false}}'"

Output: logs

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@gagika left a comment

thanks

@NicoGrande force-pushed the nicogrande/maxtext-vllm-rl-integration branch 8 times, most recently from d82bf6f to 28835d7 on December 16, 2025 20:19
@NicoGrande force-pushed the nicogrande/maxtext-vllm-rl-integration branch 4 times, most recently from 472355f to 9631581 on December 18, 2025 00:46
@NicoGrande force-pushed the nicogrande/maxtext-vllm-rl-integration branch 4 times, most recently from ff2417e to 6b1618a on December 20, 2025 01:43
Fix formatting.

Refactor model creation and error handling in RL training

fix linting.

adding no-op mappings to tunix adapter.

removing kvcache init for vllm case.

latest updates from debugging.

adding null logical axis rules to adapter.

adding linting fixes.

fixing pyink

remove unused imports from attentions test.

adding fixes.

addressing comments in evaluate rl.

set weight dtype to bf16 by default.

removing unnecessary logical axis rules.

removing epath.

removing deprecated .value call
@NicoGrande force-pushed the nicogrande/maxtext-vllm-rl-integration branch from 6b1618a to e0e5a25 on December 20, 2025 03:06
@A9isha left a comment

Thank you!

@copybara-service bot merged commit c2574ab into main on Dec 20, 2025
24 of 26 checks passed
@copybara-service bot deleted the nicogrande/maxtext-vllm-rl-integration branch on December 20, 2025 12:49